import logging
from copy import copy
from typing import Collection, Iterable, List, Sequence, Set

from common.credentials import Credentials, LMHash, NTHash, Password, Username
from infection_monkey.exploit.tools import (
    generate_brute_force_credentials,
    identity_type_filter,
    secret_type_filter,
)

logger = logging.getLogger(__name__)


def generate_powershell_credentials(
    credentials: Sequence[Credentials], running_from_windows: bool
) -> Sequence[Credentials]:
    """
    Build the ordered, de-duplicated list of credentials to try against a PowerShell target.

    The result contains:
      * the Username/secret combinations produced by generate_brute_force_credentials() from
        the supplied credentials (secrets restricted to Password, LMHash and NTHash),
      * "special" credentials for every known username (empty passwords and, when running
        from Windows, cached-session credentials),
      * domain-qualified variants of the above, as produced by _add_domains_to_usernames().

    :param credentials: Credentials gathered so far; this sequence is not modified
    :param running_from_windows: Whether the caller runs on Windows; this enables adding the
                                 current user's name, cached credentials and the current
                                 user's domain
    :return: A list of unique credentials, in the order they should be tried
    """
    # Don't modify the original sequence
    src_credentials: List[Credentials] = list(copy(credentials))

    # A dict is used as an insertion-ordered set so that the order of the generated special
    # credentials is reproducible. A set iterates in hash order, which for values hashed from
    # strings can differ between runs.
    usernames = dict.fromkeys(
        c.identity for c in src_credentials if isinstance(c.identity, Username)
    )

    if running_from_windows:
        try:
            local_username = _get_current_username()
        except Exception as err:
            logger.warning(f"Failed to get current username: {err}")
        else:
            usernames.setdefault(local_username, None)
            src_credentials.append(Credentials(identity=local_username, secret=None))

    powershell_credentials: List[Credentials] = []

    # Note that if no secrets are provided, generate_brute_force_credentials() will return zero
    # results. This is why usernames are captured above.
    powershell_credentials.extend(
        generate_brute_force_credentials(
            src_credentials,
            identity_filter=identity_type_filter([Username]),
            secret_filter=secret_type_filter([Password, LMHash, NTHash]),
        )
    )

    powershell_credentials.extend(
        _generate_special_credentials(list(usernames), running_from_windows)
    )
    powershell_credentials = list(
        _add_domains_to_usernames(powershell_credentials, running_from_windows)
    )

    return _remove_duplicate_credentials(powershell_credentials)


def _get_current_username() -> Username:
    """
    Return the name of the user this process runs as, without its domain prefix.

    win32api is imported inside the function, so it is only needed when this is called
    (generate_powershell_credentials() only calls it when running from Windows).

    :raises IndexError: If the SAM-compatible name contains no backslash
    """
    import win32api

    # The SAM-compatible name has the form "DOMAIN\user"; take the second component.
    sam_compatible_name = win32api.GetUserNameEx(win32api.NameSamCompatible)
    name_parts = sam_compatible_name.split("\\")
    return Username(username=name_parts[1])


def _generate_special_credentials(
    usernames: Collection[Username], running_from_windows: bool
) -> Sequence[Credentials]:
    """
    Return credentials that are not derived from any known secret.

    When running from Windows, cached-session credentials come first. Empty-password
    credentials for every username always follow.

    :param usernames: Usernames to build credentials for
    :param running_from_windows: Whether cached-session credentials should be included
    :return: A new list of credentials
    """
    cached_credentials = (
        _generate_cached_credentials(usernames) if running_from_windows else []
    )
    empty_password_credentials = _generate_empty_password_credentials(usernames)

    return [*cached_credentials, *empty_password_credentials]


def _generate_cached_credentials(usernames: Iterable[Username]) -> Sequence[Credentials]:
    """
    Return credentials whose secret is None, so that cached session credentials get used.

    On Windows, providing `None` for username and/or password will result in credentials
    "cached" in the current user's session being used. This means that the correct
    username/password does not always need to be supplied to the exploiter.

    :param usernames: Usernames to pair with a `None` secret
    :return: The fully anonymous (None, None) pair first, then one pair per username
    """
    return [Credentials(identity=identity, secret=None) for identity in (None, *usernames)]


def _generate_empty_password_credentials(usernames: Iterable[Username]) -> Sequence[Credentials]:
    """
    Return one credential per username, each with an empty-string password.

    Windows users can have "blank" passwords. Setting the password to an empty string
    attempts to authenticate as a user with a blank password.

    :param usernames: Usernames to pair with the blank password
    :return: A new list of credentials, in the order of `usernames`
    """
    blank_password = Password(password="")

    generated: List[Credentials] = []
    for username in usernames:
        generated.append(Credentials(identity=username, secret=blank_password))
    return generated


def _add_domains_to_usernames(
    credentials: Sequence[Credentials], running_from_windows: bool
) -> Sequence[Credentials]:
    r"""
    Return the credentials interleaved with domain-qualified copies of each one.

    The order of the input credentials is preserved. Each credential is immediately
    followed by one copy per domain from _get_domains(). Each copy has the username
    rewritten as "{domain}\{username}" and keeps the same secret. Credentials without a
    Username identity (e.g. identity=None), and those whose username already contains a
    backslash, are passed through without copies.

    In some scenarios (such as multi-hop propagation), the domain is required to
    authenticate. Sometimes the local domain (".\") is sufficient, but other times a real
    domain must be supplied. This function tries both.

    :param credentials: The credentials to expand
    :param running_from_windows: Passed to _get_domains(), which only looks up the current
                                 user's domain when this is True
    :return: A new list of credentials
    """
    # _get_domains() returns a set, whose iteration order is not guaranteed to be the same
    # between runs. Sort it once, outside the loop, so the order of the domain variants is
    # reproducible ("." sorts before any domain name starting with a letter or digit).
    ordered_domains = sorted(_get_domains(running_from_windows))

    credentials_with_domains: List[Credentials] = []
    for credential in credentials:
        credentials_with_domains.append(credential)

        # Only Username identities can be qualified with a domain.
        if not isinstance(credential.identity, Username):
            continue

        username = credential.identity.username
        # Already domain-qualified; don't prepend another domain.
        if "\\" in username:
            continue

        credentials_with_domains.extend(
            Credentials(
                identity=Username(username=f"{domain}\\{username}"), secret=credential.secret
            )
            for domain in ordered_domains
        )

    return credentials_with_domains


def _get_domains(running_from_windows: bool) -> Set[str]:
    domains = {"."}

    if running_from_windows:
        try:
            domains.add(_get_domain())
        except Exception as err:
            logger.warning(f"Failed to get domain: {err}")

    return domains


def _get_domain() -> str:
    r"""
    Return the domain part of the name of the user this process runs as.

    The SAM-compatible name has the form "DOMAIN\user". Everything before the first
    backslash is returned (the whole name if there is no backslash). win32api is imported
    inside the function, so it is only needed when this is called.
    """
    import win32api

    sam_compatible_name = win32api.GetUserNameEx(win32api.NameSamCompatible)
    domain, _separator, _remainder = sam_compatible_name.partition("\\")
    return domain


def _remove_duplicate_credentials(credentials: Sequence[Credentials]) -> Sequence[Credentials]:
    """
    Return the credentials with duplicates removed, keeping the first occurrence of each.

    :param credentials: Credentials that may contain duplicates; they must be hashable
    :return: A new list that preserves the order of first occurrences
    """
    # A dict (not a set) is used because dicts preserve insertion order.
    first_seen: dict = {}
    for credential in credentials:
        first_seen.setdefault(credential, None)

    return list(first_seen)
